{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# An Overview of Quantized Activations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this second tutorial, we take a deeper look at quantized activations.  \n",
    "We were already introduced to quantized activations in the previous tutorial, when we looked at input and output quantization of `QuantConv2d` with the `Int8ActPerTensorFloat` quantizer. The same result can be obtained with different syntax by coupling `QuantConv2d` with `QuantIdentity` layers, which by default use the `Int8ActPerTensorFloat` quantizer.\n",
    "As an example, we compare - on the *same input* - the result of `QuantConv2d` with `output_quant` enabled with the result of a `QuantConv2d` followed by a `QuantIdentity`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from brevitas.nn import QuantConv2d, QuantIdentity\n",
    "from brevitas.quant.scaled_int import Int8ActPerTensorFloat \n",
    "\n",
    "torch.manual_seed(0)\n",
    "output_quant_conv = QuantConv2d(\n",
    "    in_channels=2, out_channels=3, kernel_size=(3,3), output_quant=Int8ActPerTensorFloat)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "default_quant_conv = QuantConv2d(\n",
    "    in_channels=2, out_channels=3, kernel_size=(3,3))\n",
    "output_identity_quant = QuantIdentity()\n",
    "\n",
    "inp = torch.randn(1, 2, 5, 5)\n",
    "out_tensor1 = output_quant_conv(inp)\n",
    "out_tensor2 = output_identity_quant(default_quant_conv(inp))\n",
    "\n",
    "assert out_tensor1.isclose(out_tensor2).all().item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can observe a similar behaviour if we enable input quantization too:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.manual_seed(0)\n",
    "input_output_quant_conv = QuantConv2d(\n",
    "    in_channels=2, out_channels=3, kernel_size=(3,3), \n",
    "    input_quant=Int8ActPerTensorFloat, output_quant=Int8ActPerTensorFloat)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "default_quant_conv = QuantConv2d(\n",
    "    in_channels=2, out_channels=3, kernel_size=(3,3))\n",
    "input_identity_quant = QuantIdentity()\n",
    "output_identity_quant = QuantIdentity()\n",
    "\n",
    "inp = torch.randn(1, 2, 5, 5)\n",
    "out_tensor1 = input_output_quant_conv(inp)\n",
    "out_tensor2 = output_identity_quant(default_quant_conv(input_identity_quant(inp)))\n",
    "\n",
    "assert out_tensor1.isclose(out_tensor2).all().item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From an algorithmic point of view, then, the two implementations do the same thing. However, as will become clearer in later tutorials, there are currently some scenarios where picking one style over the other can make a difference when exporting to a format such as standard ONNX. In the meantime, we can just keep in mind that both alternatives exist."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As was the case with `QuantConv2d`, when we disable quantization of an activation, the layer behaves like its floating-point variant. In the case of `QuantIdentity`, that means behaving like an identity function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "disabled_quant_identity = QuantIdentity(act_quant=None)\n",
    "(inp == disabled_quant_identity(inp)).all().item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Again, as was the case for `QuantConv2d`, quantized activation layers can also return a `QuantTensor`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "QuantTensor(value=tensor([[[[-0.4566, -0.5707, -0.5517,  0.5897,  1.5409],\n",
       "          [ 0.5136, -0.5897, -0.5707,  0.1902, -0.0761],\n",
       "          [-0.4946, -1.5029, -0.1902,  0.4376,  1.3317],\n",
       "          [-1.6361,  2.0736,  1.7122,  2.3780, -1.1224],\n",
       "          [-0.3234, -1.0844, -0.0761, -0.0951, -0.7610]],\n",
       "\n",
       "         [[-1.5980,  0.0190, -0.7419,  0.1902,  0.6278],\n",
       "          [ 0.6468, -0.2473, -0.5327,  1.1605,  0.4376],\n",
       "          [-0.7990, -1.2936, -0.7419, -1.3127, -0.2283],\n",
       "          [-2.4351, -0.0761,  0.2283,  0.7990, -0.1902],\n",
       "          [-0.3615, -1.2175, -0.6278, -0.4566,  1.9214]]]],\n",
       "       grad_fn=<MulBackward0>), scale=tensor(0.0190, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "return_quant_identity = QuantIdentity(return_quant_tensor=True)\n",
    "out_tensor = return_quant_identity(inp)\n",
    "out_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "assert out_tensor.is_valid"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As expected, a `QuantIdentity` with quantization disabled also behaves like an identity function when a `QuantTensor` is passed in. However, depending on whether `return_quant_tensor` is set to `False` or `True`, quantization metadata might be stripped out. With `return_quant_tensor=False` (the default), the input `QuantTensor` is returned as an implicitly quantized `torch.Tensor`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[-0.4566, -0.5707, -0.5517,  0.5897,  1.5409],\n",
       "          [ 0.5136, -0.5897, -0.5707,  0.1902, -0.0761],\n",
       "          [-0.4946, -1.5029, -0.1902,  0.4376,  1.3317],\n",
       "          [-1.6361,  2.0736,  1.7122,  2.3780, -1.1224],\n",
       "          [-0.3234, -1.0844, -0.0761, -0.0951, -0.7610]],\n",
       "\n",
       "         [[-1.5980,  0.0190, -0.7419,  0.1902,  0.6278],\n",
       "          [ 0.6468, -0.2473, -0.5327,  1.1605,  0.4376],\n",
       "          [-0.7990, -1.2936, -0.7419, -1.3127, -0.2283],\n",
       "          [-2.4351, -0.0761,  0.2283,  0.7990, -0.1902],\n",
       "          [-0.3615, -1.2175, -0.6278, -0.4566,  1.9214]]]],\n",
       "       grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out_torch_tensor = disabled_quant_identity(out_tensor)\n",
    "out_torch_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "QuantTensor(value=tensor([[[[-0.4566, -0.5707, -0.5517,  0.5897,  1.5409],\n",
       "          [ 0.5136, -0.5897, -0.5707,  0.1902, -0.0761],\n",
       "          [-0.4946, -1.5029, -0.1902,  0.4376,  1.3317],\n",
       "          [-1.6361,  2.0736,  1.7122,  2.3780, -1.1224],\n",
       "          [-0.3234, -1.0844, -0.0761, -0.0951, -0.7610]],\n",
       "\n",
       "         [[-1.5980,  0.0190, -0.7419,  0.1902,  0.6278],\n",
       "          [ 0.6468, -0.2473, -0.5327,  1.1605,  0.4376],\n",
       "          [-0.7990, -1.2936, -0.7419, -1.3127, -0.2283],\n",
       "          [-2.4351, -0.0761,  0.2283,  0.7990, -0.1902],\n",
       "          [-0.3615, -1.2175, -0.6278, -0.4566,  1.9214]]]],\n",
       "       grad_fn=<MulBackward0>), scale=tensor(0.0190, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "return_disabled_quant_identity = QuantIdentity(act_quant=None, return_quant_tensor=True)\n",
    "identity_out_tensor = return_disabled_quant_identity(out_tensor)\n",
    "identity_out_tensor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Moving on from `QuantIdentity`, let's take a look at `QuantReLU`. Everything we have said so far about `QuantIdentity` also applies to `QuantReLU`. The difference is that `QuantReLU` implements a ReLU function followed by quantization, while `QuantIdentity` is really just the quantization operator. Additionally, by default `QuantReLU` adopts the `Uint8ActPerTensorFloat` quantizer, meaning that the output of quantization is *unsigned*:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "QuantTensor(value=tensor([[[[0.0000, 0.0000, 0.0000, 0.5974, 1.5402],\n",
       "          [0.5041, 0.0000, 0.0000, 0.1867, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.4481, 1.3255],\n",
       "          [0.0000, 2.0817, 1.7083, 2.3804, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
       "\n",
       "         [[0.0000, 0.0187, 0.0000, 0.1867, 0.6254],\n",
       "          [0.6348, 0.0000, 0.0000, 1.1668, 0.4387],\n",
       "          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.2334, 0.7935, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.0000, 1.9230]]]], grad_fn=<MulBackward0>), scale=tensor(0.0093, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from brevitas.nn import QuantReLU\n",
    "\n",
    "return_quant_relu = QuantReLU(return_quant_tensor=True)\n",
    "return_quant_relu(inp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`QuantReLU`, like `QuantIdentity`, is also special compared to other non-linear quantized activation layers as it preserves the metadata of an input `QuantTensor` even when quantization is disabled:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "return_disabled_quant_relu = QuantReLU(act_quant=None, return_quant_tensor=True)\n",
    "relu_out_tensor = return_disabled_quant_relu(out_tensor)\n",
    "assert relu_out_tensor.is_valid\n",
    "assert relu_out_tensor.scale == out_tensor.scale\n",
    "assert relu_out_tensor.zero_point == out_tensor.zero_point\n",
    "assert relu_out_tensor.bit_width == out_tensor.bit_width"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That doesn't apply to other layers like, say, `QuantSigmoid`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "QuantTensor(value=(tensor([[[[0.3878, 0.3611, 0.3655, 0.6433, 0.8236],\n",
       "          [0.6257, 0.3567, 0.3611, 0.5474, 0.4810],\n",
       "          [0.3788, 0.1820, 0.4526, 0.6077, 0.7911],\n",
       "          [0.1630, 0.8883, 0.8471, 0.9151, 0.2456],\n",
       "          [0.4198, 0.2527, 0.4810, 0.4762, 0.3184]],\n",
       "\n",
       "         [[0.1683, 0.5048, 0.3226, 0.5474, 0.6520],\n",
       "          [0.6563, 0.4385, 0.3699, 0.7614, 0.6077],\n",
       "          [0.3102, 0.2152, 0.3226, 0.2120, 0.4432],\n",
       "          [0.0805, 0.4810, 0.5568, 0.6898, 0.4526],\n",
       "          [0.4106, 0.2284, 0.3480, 0.3878, 0.8723]]]],\n",
       "       grad_fn=<SigmoidBackward0>), None, None, None), scale=None, zero_point=None, bit_width=None, signed_t=None, training_t=tensor(True))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from brevitas.nn import QuantSigmoid\n",
    "\n",
    "return_disabled_quant_sigmoid = QuantSigmoid(act_quant=None, return_quant_tensor=True)\n",
    "sigmoid_out_tensor = return_disabled_quant_sigmoid(out_tensor)\n",
    "sigmoid_out_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "assert not sigmoid_out_tensor.is_valid"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Something to always keep in mind is that the non-linearity of a quantized activation layer is always called on the *dequantized* representation of the input.\n",
    "For example, let's say we first quantize a floating-point `torch.Tensor` with an unsigned shifted quantizer such as `ShiftedUint8ActPerTensorFloat`, i.e. with zero-point such that the integer representation of its output is non-negative.\n",
    "Then, we pass this tensor as input to a `QuantReLU` with quantization *disabled*. The fact that the input to `QuantReLU` in its integer form is unsigned doesn't mean `QuantReLU` won't have any effect, as ReLU is called on the dequantized representation, which includes both *positive* and *negative* values:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "QuantTensor(value=tensor([[[[0.0000, 0.0000, 0.0000, 0.5854, 1.5485],\n",
       "          [0.5099, 0.0000, 0.0000, 0.1888, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.4532, 1.3219],\n",
       "          [0.0000, 2.0772, 1.6996, 2.3794, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
       "\n",
       "         [[0.0000, 0.0189, 0.0000, 0.1888, 0.6232],\n",
       "          [0.6421, 0.0000, 0.0000, 1.1708, 0.4343],\n",
       "          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.2266, 0.7931, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.0000, 1.9262]]]], grad_fn=<ReluBackward0>), scale=tensor(0.0189, grad_fn=<DivBackward0>), zero_point=tensor(129., grad_fn=<SWhereBackward0>), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat\n",
    "\n",
    "shifted_quant_identity = QuantIdentity(act_quant=ShiftedUint8ActPerTensorFloat, return_quant_tensor=True)\n",
    "return_disabled_quant_relu = QuantReLU(act_quant=None, return_quant_tensor=True)\n",
    "return_disabled_quant_relu(shifted_quant_identity(inp))"
   ]
  },
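  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the point above concrete, here is a minimal hand-written sketch of an unsigned 8-bit quantizer with a zero-point, in plain Python. It is not how Brevitas implements quantization: the `scale`, `zero_point`, and input values are assumptions picked for readability, and `quantize` and `dequantize` are hypothetical helper names. It shows that the integer representation is entirely non-negative, while the dequantized values that ReLU actually sees can be negative:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only, not Brevitas internals.\n",
    "# scale, zero_point and the inputs are assumed values chosen for readability.\n",
    "scale = 0.02\n",
    "zero_point = 128\n",
    "values = [-1.0, -0.3, 0.0, 0.5, 1.2]\n",
    "\n",
    "\n",
    "def quantize(x):\n",
    "    # Map a float to an unsigned 8-bit integer in [0, 255].\n",
    "    q = round(x / scale) + zero_point\n",
    "    return max(0, min(255, q))\n",
    "\n",
    "\n",
    "def dequantize(q):\n",
    "    # Recover the real value that the integer stands for.\n",
    "    return (q - zero_point) * scale\n",
    "\n",
    "\n",
    "int_repr = [quantize(v) for v in values]\n",
    "dequant = [dequantize(q) for q in int_repr]\n",
    "# ReLU is applied to the dequantized values, not to the integers.\n",
    "relu_out = [max(0.0, d) for d in dequant]\n",
    "\n",
    "assert all(q >= 0 for q in int_repr)  # integers are all non-negative\n",
    "assert any(d < 0 for d in dequant)    # yet some represented values are negative\n",
    "print(int_repr)\n",
    "print(relu_out)"
   ]
  },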
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's now consider the very common scenario of a `QuantConv2d` followed by a `ReLU` or `QuantReLU`.\n",
    "In particular, let's say we have a `QuantConv2d` with output quantization *enabled* followed by a `ReLU`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[0.0000, 0.0000, 0.0000],\n",
       "          [1.3134, 1.2557, 1.0392],\n",
       "          [0.4186, 0.0000, 0.0000]],\n",
       "\n",
       "         [[0.7361, 0.5340, 0.8516],\n",
       "          [0.2887, 0.3175, 0.0000],\n",
       "          [0.8949, 1.6743, 0.0722]],\n",
       "\n",
       "         [[0.0000, 0.0000, 0.0289],\n",
       "          [0.0000, 0.0000, 0.2021],\n",
       "          [0.0000, 0.0000, 0.4907]]]], grad_fn=<ReluBackward0>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.manual_seed(0)\n",
    "output_quant_conv = QuantConv2d(\n",
    "    in_channels=2, out_channels=3, kernel_size=(3,3), output_quant=Int8ActPerTensorFloat)\n",
    "torch.relu(output_quant_conv(inp))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We compare it against a `QuantConv2d` with default settings (i.e. output quantization *disabled*), followed by a `QuantReLU` with default settings (i.e. activation quantization *enabled*):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[0.0000, 0.0000, 0.0000],\n",
       "          [1.3078, 1.2555, 1.0397],\n",
       "          [0.4185, 0.0000, 0.0000]],\n",
       "\n",
       "         [[0.7454, 0.5427, 0.8566],\n",
       "          [0.2943, 0.3269, 0.0000],\n",
       "          [0.8893, 1.6674, 0.0785]],\n",
       "\n",
       "         [[0.0065, 0.0000, 0.0262],\n",
       "          [0.0000, 0.0000, 0.1962],\n",
       "          [0.0000, 0.0000, 0.4839]]]], grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.manual_seed(0)\n",
    "default_quant_conv = QuantConv2d(\n",
    "    in_channels=2, out_channels=3, kernel_size=(3,3))\n",
    "default_quant_relu = QuantReLU()\n",
    "default_quant_relu(default_quant_conv(inp))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see the results are close but not quite the same.  \n",
    "In the first case, we quantized the output of `QuantConv2d` with an 8-bit signed quantizer, and then we passed it through a `ReLU`. Half of the numerical range covered by the signed quantizer is therefore lost, and for all practical purposes the output can now be treated as a 7-bit unsigned number (although it's not explicitly marked as such).\n",
    "In the second case, we perform unsigned 8-bit quantization after `ReLU`. Because the range covered by the quantizer now includes only non-negative numbers, we don't waste a bit as in the previous case."
   ]
  },
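  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The lost bit can be checked with a bit of arithmetic. The plain-Python sketch below counts the integer levels that survive a ReLU in each ordering and compares the resulting step sizes. The maximum activation value `max_activation` and the way the step size is derived from it are assumptions made for illustration, not necessarily the convention Brevitas uses:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch: count the quantization levels that remain useful after a ReLU.\n",
    "bit_width = 8\n",
    "\n",
    "# Signed 8-bit integers span [-128, 127]; ReLU keeps only the non-negative ones.\n",
    "signed_levels = range(-(2 ** (bit_width - 1)), 2 ** (bit_width - 1))\n",
    "levels_after_relu = [q for q in signed_levels if q >= 0]\n",
    "\n",
    "# Unsigned 8-bit integers span [0, 255]; all of them are already non-negative.\n",
    "unsigned_levels = range(0, 2 ** bit_width)\n",
    "\n",
    "print(len(levels_after_relu))  # 128 levels, i.e. 7 bits\n",
    "print(len(unsigned_levels))    # 256 levels, i.e. 8 bits\n",
    "\n",
    "# For the same (assumed) maximum activation, the unsigned grid is about twice as fine.\n",
    "max_activation = 2.0\n",
    "step_signed_then_relu = max_activation / max(levels_after_relu)\n",
    "step_relu_then_unsigned = max_activation / max(unsigned_levels)\n",
    "assert step_relu_then_unsigned < step_signed_then_relu"
   ]
  },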
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A word of caution regarding some premade activation quantizers, such as `Uint8ActPerTensorFloat`, `ShiftedUint8ActPerTensorFloat`, and `Int8ActPerTensorFloat`, which anticipates some of the themes of the next tutorial.\n",
    "To minimize user interaction, Brevitas initializes scale and zero-point by collecting statistics for a number of training steps (by default 30). This can be seen as a sort of very basic calibration step, although it typically happens during training and with quantization already enabled. These statistics are accumulated in an exponential moving average that, at the end of the collection phase, is used to initialize a learned *parameter*.\n",
    "During the collection phase, then, the quantizer behaves differently between `train()` and `eval()` mode. In `train()` mode, the statistics for that particular batch are returned. In `eval()` mode, the exponential moving average is returned. After the collection phase is over, the learned parameter is returned in both execution modes.\n",
    "We can easily observe this behaviour with an example. Let's first define a quantized activation and two random input tensors:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "quant_identity = QuantIdentity(return_quant_tensor=True)\n",
    "inp1 = torch.randn(3, 3)\n",
    "inp2 = torch.randn(3, 3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We then compare the output scale factors of the two tensors between `train()` and `eval()` mode. In train mode the scale factors generally differ, while in eval mode they are the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out1_train = quant_identity(inp1)\n",
    "out2_train = quant_identity(inp2)\n",
    "assert not out1_train.scale.isclose(out2_train.scale).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "quant_identity.eval()\n",
    "out1_eval = quant_identity(inp1)\n",
    "out2_eval = quant_identity(inp2)\n",
    "assert out1_eval.scale.isclose(out2_eval.scale).item()"
   ]
  },
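  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The train/eval difference observed above can be mimicked with a toy, pure-Python stand-in. This is only a sketch of the described behaviour, not Brevitas code: the class `ToyStatsScale`, its absolute-max statistic, and its `momentum` value are hypothetical, and only the 30-step collection window is taken from the description above. In train mode it returns the statistic of the current batch while updating a running average; in eval mode it returns the running average:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ToyStatsScale:\n",
    "    # Hypothetical stand-in for statistics-based scale initialization.\n",
    "\n",
    "    def __init__(self, collect_steps=30, momentum=0.1):\n",
    "        self.collect_steps = collect_steps\n",
    "        self.momentum = momentum  # assumed value, for illustration only\n",
    "        self.steps = 0\n",
    "        self.ema = None\n",
    "        self.learned = None\n",
    "        self.training = True\n",
    "\n",
    "    def __call__(self, batch):\n",
    "        if self.learned is not None:\n",
    "            return self.learned  # after collection: same value in both modes\n",
    "        stat = max(abs(x) for x in batch)  # per-batch statistic (assumed: absolute max)\n",
    "        if not self.training:\n",
    "            # eval mode during collection: return the running average\n",
    "            return self.ema if self.ema is not None else stat\n",
    "        # train mode during collection: update the average, return the batch statistic\n",
    "        if self.ema is None:\n",
    "            self.ema = stat\n",
    "        else:\n",
    "            self.ema = (1 - self.momentum) * self.ema + self.momentum * stat\n",
    "        self.steps += 1\n",
    "        if self.steps == self.collect_steps:\n",
    "            self.learned = self.ema  # stands in for initializing a learned parameter\n",
    "        return stat\n",
    "\n",
    "\n",
    "toy_scale = ToyStatsScale()\n",
    "batch_a, batch_b = [0.5, -1.0, 0.2], [2.0, -0.1, 0.3]\n",
    "train_a, train_b = toy_scale(batch_a), toy_scale(batch_b)\n",
    "toy_scale.training = False\n",
    "eval_a, eval_b = toy_scale(batch_a), toy_scale(batch_b)\n",
    "\n",
    "assert train_a != train_b  # train mode: per-batch statistics differ\n",
    "assert eval_a == eval_b    # eval mode: both see the same running average"
   ]
  },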
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, the only layer that is an exception to this is `QuantHardTanh`. That is because the interface to `torch.nn.Hardtanh` already exposes `min_val` and `max_val` as explicit arguments, so Brevitas preserves them whether quantization is enabled or disabled. With quantization enabled, by default those values are used for initialization, but then the range is learned. Let's look at an example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "tags": [
     "raises-exception"
    ]
   },
   "outputs": [
    {
     "ename": "DependencyError",
     "evalue": "'Int8ActPerTensorFloatMinMaxInit' can not resolve attribute 'max_val' while building 'scaling_init_impl'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mDependencyError\u001b[0m                           Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-18-8145d2f87fcb>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mbrevitas\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnn\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mQuantHardTanh\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      2\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m \u001b[0mQuantHardTanh\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\quant_activation.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_quant, input_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[0;32m    117\u001b[0m             \u001b[0mact_quant\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mact_quant\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    118\u001b[0m             \u001b[0mreturn_quant_tensor\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mreturn_quant_tensor\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 119\u001b[1;33m             **kwargs)\n\u001b[0m\u001b[0;32m    120\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    121\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\quant_layer.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_impl, passthrough_act, input_quant, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[0;32m     77\u001b[0m             \u001b[0mpassthrough_act\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     78\u001b[0m             \u001b[0mact_quant\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 79\u001b[1;33m             **kwargs)\n\u001b[0m\u001b[0;32m     80\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     81\u001b[0m     \u001b[1;33m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\mixin\\act.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_impl, passthrough_act, act_quant, **kwargs)\u001b[0m\n\u001b[0;32m    157\u001b[0m             \u001b[0mproxy_prefix\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'act_'\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    158\u001b[0m             \u001b[0mkwargs_prefix\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m''\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 159\u001b[1;33m             **kwargs)\n\u001b[0m\u001b[0;32m    160\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    161\u001b[0m     \u001b[1;33m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\mixin\\base.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant, proxy_protocol, none_quant_injector, proxy_prefix, kwargs_prefix, **kwargs)\u001b[0m\n\u001b[0;32m     98\u001b[0m             \u001b[0mquant_injector\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     99\u001b[0m             \u001b[0mquant_injector\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlet\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0mfilter_kwargs\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mkwargs_prefix\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 100\u001b[1;33m             \u001b[0mquant\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mproxy_class\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    101\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    102\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mproxy_protocol\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\runtime_quant.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[0;32m    108\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    109\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_layer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 110\u001b[1;33m         \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mActQuantProxyFromInjector\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_layer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    111\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_passthrough_act\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_is_passthrough_act\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    112\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\quant_proxy.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant_layer, quant_injector, export_mode, export_handler)\u001b[0m\n\u001b[0;32m     74\u001b[0m         \u001b[1;31m# Use a normal list and not a ModuleList since this is a pointer to parent modules\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     75\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtracked_module_list\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 76\u001b[1;33m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd_tracked_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_layer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     77\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_handler\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mexport_handler\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     78\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_mode\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mexport_mode\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\quant_proxy.py\u001b[0m in \u001b[0;36madd_tracked_module\u001b[1;34m(self, module)\u001b[0m\n\u001b[0;32m    130\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtracked_module_list\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    131\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mupdate_tracked_modules\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 132\u001b[1;33m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minit_tensor_quant\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    133\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    134\u001b[0m             \u001b[1;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Trying to add None as a parent module.\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\runtime_quant.py\u001b[0m in \u001b[0;36minit_tensor_quant\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    120\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    121\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0minit_tensor_quant\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 122\u001b[1;33m         \u001b[0mtensor_quant\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    123\u001b[0m         \u001b[0mact_impl\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mact_impl\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    124\u001b[0m         \u001b[0mis_act_enabled\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_is_act_enabled\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mact_impl\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtensor_quant\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "    \u001b[1;31m[... skipping hidden 1 frame]\u001b[0m\n",
      "\u001b[1;31mDependencyError\u001b[0m: 'Int8ActPerTensorFloatMinMaxInit' can not resolve attribute 'max_val' while building 'scaling_init_impl'"
     ]
    }
   ],
   "source": [
    "from brevitas.nn import QuantHardTanh\n",
    "\n",
    "QuantHardTanh()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As expected, we get an error concerning a missing `max_val` attribute. Let's pass it then, together with `min_val`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "quant_hard_tanh = QuantHardTanh(max_val=1.0, min_val=-1.0, return_quant_tensor=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The layer is now correctly initialized. We can see that the output scale factors are all the same between `train()` and `eval()` mode:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out1_train = quant_hard_tanh(inp1)\n",
    "quant_hard_tanh.eval()\n",
    "out2_eval = quant_hard_tanh(inp2)\n",
    "assert out1_train.scale.isclose(out2_eval.scale).item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, a reminder that mixing things up is perfectly legal and encouraged in Brevitas.\n",
    "For example, a `QuantIdentity` with `act_quant=Int8ActPerTensorFloatMinMaxInit` is equivalent to a default `QuantHardTanh`, or conversely a `QuantHardTanh` with `act_quant=Int8ActPerTensorFloat` is equivalent to a default `QuantIdentity`. This is possible because, as will be explained in the next tutorial, the same layer can accept different keyword arguments when different quantizers are set. So a `QuantIdentity` with `act_quant=Int8ActPerTensorFloatMinMaxInit` is going to expect arguments `min_val` and `max_val` the same way a default `QuantHardTanh` would."
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
